
llama : disable graph reuse with pipeline parallelism #20463

Merged
ggerganov merged 2 commits into master from gg/llama-disable-graph-reuse-with-pp on Mar 12, 2026

Conversation


@ggerganov ggerganov commented Mar 12, 2026

The following repro demonstrates the issue:

make -j && ./bin/llama-perplexity -hf ggml-org/Qwen3-0.6B-GGUF -f wiki.test.raw --chunks 16 -ngl 99 -ub 512 -b 2048

PPL = 2382.4719 +/- 246.20903

The problem seems to occur when 2 consecutive pp ubatches both output logits, which happens in perplexity runs with the above parameters: 4 ubatches of size 512, the latter 2 of which output logits. I think the graph reuse logic somehow conflicts with the scheduler's logic for tracking the current copy:

GGML_ASSERT(!sched->is_alloc);
sched->cur_copy = sched->next_copy;
sched->next_copy = (sched->next_copy + 1) % sched->n_copies;
ggml_backend_sched_split_graph(sched, graph);

For now, graph reuse is disabled when pipeline parallelism is active as a workaround. Proper investigation is necessary.


Additionally, after #17795, disabling graph reuse alone is not enough to fix the issue, so for now that change is also reverted.

Note that both commits in this PR are needed. Neither one fixes the issue alone.

cc @aendk @gaugarg-nv

@ggerganov ggerganov force-pushed the gg/llama-disable-graph-reuse-with-pp branch from 7bc73ae to dfa3ad1 on March 12, 2026 16:05
@github-actions github-actions bot added the Nvidia GPU and ggml labels Mar 12, 2026
Collaborator

@ORippler ORippler left a comment


Let's take the time to understand what's going on here

@ggerganov ggerganov merged commit 57819b8 into master Mar 12, 2026
48 of 76 checks passed
@ggerganov ggerganov deleted the gg/llama-disable-graph-reuse-with-pp branch March 12, 2026 19:50
@Superbobo75

hi,

Starting from build b8323 of llama.cpp, where graph reuse was disabled, the inference speed of the model Unsloth/Qwen3.5-35B-A3B-Q5_K_M.gguf dropped significantly. On my hardware setup, 2x RTX 5060 Ti (16 GB VRAM each) and 64 GB DDR4 RAM, performance fell from approximately 85 tokens per second (t/s) to around 50 t/s.

Due to this substantial regression, I am currently staying on the last known working version prior to this change. I have not yet tested other models to see whether the issue is isolated to this specific architecture or configuration.

@aendk
Contributor

aendk commented Mar 18, 2026

@ggerganov thanks for the repro.
I am currently looking into this.
For preliminary results, I compared the PPL of three versions of the code:

  1. The build of this merged code (b8323) as ground truth
  2. The previous build (b8322)
  3. The previous build (b8322), but with copy_from_host set to false

In my setup (RTX PRO 6000, 4500), I get the following PPLs:

  1. PPL = 28.8008 +/- 1.50705
  2. PPL = 3964.6876 +/- 382.88199
  3. PPL = 28.8008 +/- 1.50705

This matches my initial hypothesis and the concerns I voiced in the initial PR (#17795 (comment)).
My aim was to relax only the CPU split -> CUDA split synchronization, since those splits are inherently synced. This is what I wanted to check explicitly via ggml_backend_is_cpu(backend_src).
Due to API constraints, I relaxed this check to also allow copies from any CPU buffer to a CUDA buffer, by additionally checking ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU.
In the pipeline parallelism setting here, that check is not stringent enough: the new path is taken where it shouldn't be.

The fix is therefore to make the copy_from_cpu check more stringent, and to test thoroughly with llama-perplexity before re-proposing the change. Alternatively, disabling it in multi-GPU settings could serve as a workaround.
See comment below

Regarding the graph reuse logic, I need to dig deeper. According to the PPL results above, just disabling the incorrect part of my PR is enough. Did you test on workstation GPUs too, or were you also using other GPUs (consumer/GeForce or AMD)?

@aendk
Contributor

aendk commented Mar 19, 2026

With further analysis, I need to retract my first assumption.

The copy_from_host change was not directly at fault here.
In my testing, pipeline parallelism with 2 GPUs leads to 3 splits: the input embedding split on the CPU, then one split each on GPU0 and GPU1.
Between GPU0 and GPU1, activation tensors were copied asynchronously as before, and inputs such as weights were also copied correctly from the CPU to GPU1. I saw no incorrect use of the new copy_from_host path.

I rather think that the event-based scheduling mechanism implicitly relied on the synchronizations in the blocking CPU->GPU copy operations, and that removing them surfaced a pre-existing bug/blurry edge case in that mechanism.
To be specific, each CPU->GPU copy synced twice: once via a standalone ggml_backend_cuda_synchronize(split_backend), and then via a cudaStreamSynchronize(cudaStreamPerThread) inside ggml_backend_tensor_copy. With my PR, every copy became completely asynchronous.

I briefly tested the assumption that the event-based scheduling mechanism for pipeline parallelism implicitly requires stream synchronizations similar to the single-GPU case.
In my initial PR, we went for the saaasg pattern for single-GPU setups (s = sync, a = async copy, g = graph execution).
My final implementation executed this new additional s just before the graph execution (g) in single-GPU setups only, and not for pipeline parallelism, on the assumption that pipeline parallelism does not require it.
However, executing this new s (ggml_backend_cuda_synchronize(split_backend)) during pipeline parallelism as well seems to fix the bug: llama-perplexity shows values identical to master, and performance on Linux is unchanged too.

I will open a new draft PR for further discussions on this.
